from PyQt5 import QtGui, QtCore
from PyQt5.QtCore import QEvent, Qt, QLineF, QPoint, QPointF
from PyQt5.QtGui import QPainter, QColor
from PyQt5.QtWidgets import QWidget

from structures.basic_structure import BasicStructure
from view.layers.base_canvas_layer import BaseLayer


class RecordLayer(BaseLayer):
    """Canvas layer that records polyline trajectories drawn with the mouse.

    Trajectories are stored in ``record_points_list`` as lists of points
    (``[x, y]`` sequences). The last entry of ``record_points_list`` is always
    the trajectory currently being edited, which is also referenced by
    ``_record_points``. Completed trajectories come before it.

    Mouse interaction (see ``handle_drawing_event``):
      * left click        -- append a point at the cursor
      * left double-click -- finish the current trajectory
      * middle click      -- close the current trajectory back to its first
                             point (needs more than two points) and finish it
      * right click       -- undo the last point, or reopen the previous
                             trajectory when the current one is empty
    """

    LAYER_NAME = "Record"

    def __init__(self, ls: BasicStructure):
        # NOTE(review): self._ls is used below but never assigned here;
        # presumably BaseLayer.__init__ stores ``ls`` there -- confirm.
        super().__init__(ls)

        # Pens for trajectory segments, the cursor rubber-band line,
        # keypoints, and markers around unreachable points.
        self._record_pen = QtGui.QPen(QtCore.Qt.GlobalColor.red, 1, QtCore.Qt.PenStyle.SolidLine)
        self._record_pen_dash = QtGui.QPen(QtCore.Qt.GlobalColor.magenta, 2, QtCore.Qt.PenStyle.DashLine)
        self._record_keypoint_pen = QtGui.QPen(QtCore.Qt.GlobalColor.blue, 3, QtCore.Qt.PenStyle.SolidLine)
        self._record_unreachable_pen = QtGui.QPen(QColor(26, 235, 255), 2, QtCore.Qt.PenStyle.SolidLine)

        self.record_points_list = []  # all trajectories; last one is being edited
        self._record_points = []      # alias of record_points_list[-1]
        self._outer_points = []       # points flagged by the last reachability check
        self.init_record()

        # Last cursor position seen by handle_drawing_event, or None.
        self.__cursor_point = None

    def init_record(self, init_points_list=None):
        """Reset the recording, optionally seeding it with saved trajectories.

        :param init_points_list: optional iterable of trajectories. Each
            trajectory is an iterable of points, and every point is passed
            through ``self._ls.forward_ratio`` to get the coordinates used for
            display. Empty trajectories are dropped.

        Afterwards a fresh, empty trajectory is appended and becomes the one
        being edited, and the unreachable-point markers are cleared.
        """
        if init_points_list is None:
            self.record_points_list = []
        else:
            print("----------load-----------")
            print(init_points_list)
            final_points_list = []
            for key_points in init_points_list:
                # Skip trajectories with no points.
                if len(key_points) == 0:
                    continue
                # Scale every point by the structure's ratio to get the
                # final points used for display.
                key_points = [self._ls.forward_ratio(point) for point in key_points]

                final_points_list.append(key_points)

            self.record_points_list = final_points_list

        # Start a new, empty trajectory as the one being edited.
        self._record_points = []
        self.record_points_list.append(self._record_points)

        self._outer_points = []

    def update_last_record_traj_horizontal(self):
        """Mirror the most recent non-empty trajectory horizontally.

        The x value of every point in that trajectory is negated, and empty
        trajectories are skipped. The trajectory list is updated in place, so
        ``_record_points`` still refers to the right object when the current
        trajectory is the one being flipped.

        NOTE(review): negating x mirrors about x == 0 in display
        coordinates. This assumes the canvas origin lies on the intended
        mirror axis -- confirm.

        :return: True if a trajectory was flipped, False if every trajectory
            is empty.
        """
        for points in reversed(self.record_points_list):
            if len(points) == 0:
                continue
            # Negate the x value of every point. Build new point lists
            # instead of mutating each point in place, so a point object
            # shared with another list is never modified (or flipped twice).
            points[:] = [[-p[0], *p[1:]] for p in points]
            # Markers from an earlier reachability check refer to the old
            # positions and no longer apply.
            self._outer_points = []
            return True
        return False

    def check_all_points_in_reachable_range(self):
        """Check every recorded point with ``self._ls.try_update_point_end``.

        Points for which that call returns a falsy value are collected in
        ``_outer_points`` so ``render`` can mark them.

        NOTE(review): the callee's name suggests it may update the
        structure's state; confirm it is acceptable to call it just to check.

        :return: True if no point was flagged as unreachable.
        """
        self._outer_points = []
        for points in self.record_points_list:
            self._outer_points.extend([p for p in points if not self._ls.try_update_point_end(p)])

        return len(self._outer_points) == 0

    def handle_drawing_event(self, event, x, y):
        """Process a mouse event at canvas position ``(x, y)``.

        The cursor position is always remembered for the rubber-band preview.

        :return: True if the event changed what this layer draws (point added
            or removed, trajectory finished, or cursor moved while a
            trajectory is in progress). Otherwise False. Presumably the caller
            uses this to decide whether to repaint.
        """
        cursor = [x, y]
        self.__cursor_point = cursor

        if event.type() == QEvent.MouseButtonPress:

            if event.button() == Qt.LeftButton:
                # Ignore clicks that (almost) duplicate the previous point.
                if self.is_close_to_last():
                    return False
                self._append_point(cursor)
                return True

            elif event.button() == Qt.RightButton:
                return self.remove_last_point()
            elif event.button() == Qt.MiddleButton:
                if len(self._record_points) > 2:
                    # Close the loop with a *copy* of the first point so the
                    # first and last vertices are independent objects.
                    self._append_point(list(self._record_points[0]))
                    return self.record_to_list()
                return False

        elif event.type() == QEvent.MouseButtonDblClick:
            if event.button() == Qt.LeftButton:
                return self.record_to_list()
            elif event.button() == Qt.RightButton:
                return self.remove_last_point()

        elif event.type() == QEvent.MouseMove:
            # The rubber-band line only exists once a point has been placed.
            if len(self._record_points) == 0:
                return False

            return True

        return False

    def _append_point(self, cursor):
        """Append ``cursor`` to the trajectory being edited."""
        self._record_points.append(cursor)

    def is_close_to_last(self):
        """Return True if the cursor is within 3 units of the last point on both axes.

        Always False when the current trajectory has no points.
        """
        if len(self._record_points) == 0:
            return False
        last_p = self._record_points[-1]
        curr_p = self.__cursor_point
        return abs(curr_p[0] - last_p[0]) < 3 and abs(curr_p[1] - last_p[1]) < 3

    def remove_last_point(self):
        """Undo one step of recording.

        Removes the last point of the current trajectory. If the current
        trajectory is already empty, it is discarded (unless it is the only
        one) and the previous trajectory becomes the one being edited.

        :return: always True.
        """
        if len(self._record_points) > 0:
            self._record_points.pop()
        else:
            if len(self.record_points_list) > 1:
                self.record_points_list.pop()

            self._record_points = self.record_points_list[-1]
        return True

    def record_to_list(self):
        """Finish the current trajectory and start a new, empty one.

        :return: False (and do nothing) if the current trajectory is empty,
            otherwise True.
        """
        if len(self._record_points) == 0:
            return False

        self._record_points = []
        self.record_points_list.append(self._record_points)
        return True

    def render(self, qp: QPainter, widget: QWidget):
        """Draw all trajectories, the cursor rubber-band line and unreachable markers."""

        for points in self.record_points_list:
            # Draw the mouse-recorded path.
            record_points_count = len(points)
            if record_points_count == 0:
                continue

            qp.setPen(self._record_pen)

            # Draw the consecutive line segments.
            if record_points_count > 1:
                lines = [QLineF(*points[i], *points[i + 1]) for i in range(record_points_count - 1)]
                qp.drawLines(lines)

            # Draw the individual points.
            qp.setPen(self._record_keypoint_pen)
            qp.drawPoints(*[QPointF(p[0], p[1]) for p in points])

        if self.__cursor_point is not None:
            points = self.record_points_list[-1]
            if len(points) > 0:
                # Dashed line from the last point to the mouse cursor.
                qp.setPen(self._record_pen_dash)
                qp.drawLine(QPointF(*points[-1]), QPointF(*self.__cursor_point))

        # Circle the points flagged as unreachable.
        if len(self._outer_points) > 0:
            qp.setPen(self._record_unreachable_pen)
            qp.setBrush(Qt.NoBrush)
            for p in self._outer_points:
                qp.drawEllipse(QPoint(round(p[0]), round(p[1])), 5, 5)

